#include <gtest/gtest.h>
#include <gmock/gmock.h>

#define private public
#define protected public
#include <element.h>
#undef private
#undef protected
#include <iostream>

using namespace std;
using ::testing::ContainerEq;

// helper functions
// Euclidean distance between two arrays, computed from their first three
// entries only; entries 3..5 do not take part in the result.
inline double euclideanDistance(const StdArray6d &pi, const StdArray6d &pj)
{
    double sum_of_squares = 0.0;
    int axis = 0;
    while (axis < 3)
    {
        const double delta = pi[axis] - pj[axis];
        sum_of_squares += delta * delta;
        ++axis;
    }
    return sqrt(sum_of_squares);
}

inline double euclideanDistance(const StdArray6d* pi, const StdArray6d* pj)
{
    double ans = 0.0;
    for (int i = 0; i < 3; i++)
    {
        ans += ((*pi)[i] - (*pj)[i]) * ((*pi)[i] - (*pj)[i]);
    }
    return sqrt(ans);
};

// Writes the six entries of `arr` to stdout on one line, each followed by
// ", ", then ends the line and flushes the stream.
inline void printStdArray6d(const StdArray6d &arr)
{
    int idx = 0;
    while (idx < 6)
    {
        cout << arr[idx];
        cout << ", ";
        ++idx;
    }
    cout << endl;
}

inline void printStdArray6d(const StdArray6d* arr)
{
    for (int i = 0; i < 6; i++)
    {
        cout << (*arr)[i] << ", ";
    }
    cout << endl;
}

// Fixture shared by the Beam2D tests. Before every test it builds nine
// particles, two materials, two sections and five Beam2D elements; after
// every test it frees all of them again.
class Beam2DTest: public ::testing::Test
{
    protected:
        // Particles: p1..p6 are the end points of e1..e3 (one pair each);
        // p7 is shared by e4 and e5, whose other ends are p8 and p9.
        Particle* p1;
        Particle* p2;

        Particle* p3;
        Particle* p4;

        Particle* p5;
        Particle* p6;

        Particle* p7;
        Particle* p8;
        Particle* p9;

        // Material parameters and the materials built from them.
        double E , Et, density;
        LinearElastic* mat1;
        UniBilinear* mat2;

        // Cross sections.
        Rectangle* sect1;
        AngleSteel* sect2;

        // Elements under test.
        Beam2D* e1;
        Beam2D* e2;
        Beam2D* e3;

        Beam2D* e4;
        Beam2D* e5;

        // Builds the whole object graph. Runs once per test, before SetUp().
        Beam2DTest()
        {
            // Particles; the last three arguments are the coordinates
            // (presumably the first one is the particle id -- confirm
            // against Particle's constructor).
            p1 = new Particle(1, 0, 0, 0);
            p2 = new Particle(2, 10, 0, 0);

            p3 = new Particle(3, 0, 0, 0);
            p4 = new Particle(4, 10, 0, 0);

            p5 = new Particle(5, 0, 0, 0);
            p6 = new Particle(6, 10, 0, 0);

            p7 = new Particle(7, 0, 0, 0);
            p8 = new Particle(8, 10, 0, 0);
            p9 = new Particle(9, 10, 10, 0);

            // Materials.
            E = 1e3;
            Et = 1e2;
            density = 10;

            mat1 = new LinearElastic(1);
            mat1->setE(E);
            mat1->setDensity(density);

            // NOTE(review): meaning of the trailing (20, 0) arguments is not
            // visible from this file -- see UniBilinear's constructor.
            mat2 = new UniBilinear(2, E, Et, 20, 0);
            mat2->setDensity(density);

            // Sections (argument meaning is defined by the section classes).
            sect1 = new Rectangle(1, 2, 4);
            sect2 = new AngleSteel(2, 10, 10, 2, 2);

            // Elements: e1..e3 vary material/section on identical geometry;
            // e4 and e5 share particle p7, and e5 runs from p7 to p9.
            e1 = new Beam2D(1, p1, p2, mat1, sect1);
            e2 = new Beam2D(2, p3, p4, mat2, sect1);
            e3 = new Beam2D(3, p5, p6, mat1, sect2);

            e4 = new Beam2D(4, p7, p8, mat1, sect1);
            e5 = new Beam2D(5, p7, p9, mat1, sect1);
        }

        // Frees everything the constructor allocated. Runs once per test,
        // after TearDown().
        ~Beam2DTest() override
        {
            // Each pointer needs its own delete-expression: `delete a, b;`
            // parses as `(delete a), b;` and frees only `a`.
            //
            // Elements go first because they were handed pointers to the
            // particles, materials and sections at construction time.
            delete e1;
            delete e2;
            delete e3;
            delete e4;
            delete e5;

            // Sections.
            delete sect1;
            delete sect2;

            // Materials.
            delete mat1;
            delete mat2;

            // Particles.
            delete p1;
            delete p2;
            delete p3;
            delete p4;
            delete p5;
            delete p6;
            delete p7;
            delete p8;
            delete p9;
        }

        // Called right after the constructor, before each test body.
        void SetUp() override
        {
            cout << "Set Up Beam2D Test" << endl;
        }

        // Called right after each test body, before the destructor.
        // Nothing to do: all clean-up happens in the destructor.
        void TearDown() override
        {
        }
};



// Checks the state of freshly constructed Beam2D elements (and of the
// particles they are attached to) before any force calculation has run:
// identity, connectivity, material/section, mass, zero strain/stress/forces,
// local axes, lengths, cached positions, particle DOFs and lumped masses,
// and the initial stiffness matrix.
TEST_F(Beam2DTest, IsEmptyInitiate)
{
    cout << "\n---------- Beam2DTest IsEmptyInitiate ----------" << endl;
    // initiated by BaseElement
    cout << "check id... " << endl;
    EXPECT_EQ(e1->id(), 1);
    EXPECT_EQ(e2->id(), 2);
    EXPECT_EQ(e3->id(), 3);
    EXPECT_EQ(e4->id(), 4);
    EXPECT_EQ(e5->id(), 5);

    cout << "check type... " << endl;
    char type[] = "Beam2D";
    EXPECT_STREQ(type, e1->type_.c_str());

    cout << "check num_of_particles... " << endl;
    EXPECT_EQ(2, e1->num_of_particles_);

    // Each element's particle container is expected to list its two end
    // particles in (pi, pj) order.
    cout << "check particles container... " << endl;
    auto iter = e1->particles_.begin();
    EXPECT_EQ(p1, iter->second);
    ++iter;
    EXPECT_EQ(p2, iter->second);

    iter = e2->particles_.begin();
    EXPECT_EQ(p3, iter->second);
    ++iter;
    EXPECT_EQ(p4, iter->second);

    iter = e3->particles_.begin();
    EXPECT_EQ(p5, iter->second);
    ++iter;
    EXPECT_EQ(p6, iter->second);

    iter = e4->particles_.begin();
    EXPECT_EQ(p7, iter->second);
    ++iter;
    EXPECT_EQ(p8, iter->second);

    iter = e5->particles_.begin();
    EXPECT_EQ(p7, iter->second);
    ++iter;
    EXPECT_EQ(p9, iter->second);

    cout << "check material... " << endl;
    EXPECT_EQ(mat1, e1->material_);
    EXPECT_EQ(mat2, e2->material_);
    EXPECT_EQ(mat1, e3->material_);
    EXPECT_EQ(mat1, e4->material_);
    EXPECT_EQ(mat1, e5->material_);

    cout << "check section... " << endl;
    EXPECT_EQ(sect1, e1->section_);
    EXPECT_EQ(sect1, e2->section_);
    EXPECT_EQ(sect2, e3->section_);
    EXPECT_EQ(sect1, e4->section_);
    EXPECT_EQ(sect1, e5->section_);

    // Expected element mass: density * length * cross-section area.
    cout << "check mass... " << endl;
    double mass1 = density * e1->length_ * e1->section_->A();
    double mass2 = density * e2->length_ * e2->section_->A();
    double mass3 = density * e3->length_ * e3->section_->A();
    double mass4 = density * e4->length_ * e4->section_->A();
    double mass5 = density * e5->length_ * e5->section_->A();
    EXPECT_DOUBLE_EQ(mass1, e1->mass_);
    EXPECT_DOUBLE_EQ(mass2, e2->mass_);
    EXPECT_DOUBLE_EQ(mass3, e3->mass_);
    EXPECT_DOUBLE_EQ(mass4, e4->mass_);
    EXPECT_DOUBLE_EQ(mass5, e5->mass_);

    cout << "check strain... " << endl;
    EXPECT_DOUBLE_EQ(0.0, e1->strain_.sx);
    EXPECT_DOUBLE_EQ(0.0, e1->strain_.sy);
    EXPECT_DOUBLE_EQ(0.0, e1->strain_.sz);
    EXPECT_DOUBLE_EQ(0.0, e1->strain_.sxy);
    EXPECT_DOUBLE_EQ(0.0, e1->strain_.syz);
    EXPECT_DOUBLE_EQ(0.0, e1->strain_.sxz);

    cout << "check stress... " << endl;
    EXPECT_DOUBLE_EQ(0.0, e1->stress_.sx);
    EXPECT_DOUBLE_EQ(0.0, e1->stress_.sy);
    EXPECT_DOUBLE_EQ(0.0, e1->stress_.sz);
    EXPECT_DOUBLE_EQ(0.0, e1->stress_.sxy);
    EXPECT_DOUBLE_EQ(0.0, e1->stress_.syz);
    EXPECT_DOUBLE_EQ(0.0, e1->stress_.sxz);

    cout << "check ex, ey, ez... " << endl;
    // e1: runs from (0,0,0) to (10,0,0), so ex is the global x direction.
    EXPECT_DOUBLE_EQ(1.0, e1->ex_(0));
    EXPECT_DOUBLE_EQ(0.0, e1->ex_(1));
    EXPECT_DOUBLE_EQ(0.0, e1->ex_(2));
    for (int i = 0; i < 3; i++)
    {
        EXPECT_DOUBLE_EQ(0.0, e1->ey_(i));
        EXPECT_DOUBLE_EQ(0.0, e1->ez_(i));
    }

    // e5: runs from (0,0,0) to (10,10,0), so ex is the unit diagonal.
    EXPECT_DOUBLE_EQ(1.0 / sqrt(2), e5->ex_(0));
    EXPECT_DOUBLE_EQ(1.0 / sqrt(2), e5->ex_(1));
    EXPECT_DOUBLE_EQ(0.0, e5->ex_(2));
    for (int i = 0; i < 3; i++)
    {
        EXPECT_DOUBLE_EQ(0.0, e5->ey_(i));
        EXPECT_DOUBLE_EQ(0.0, e5->ez_(i));
    }

    cout << "check is_failed... " << endl;
    EXPECT_FALSE(e1->is_failed_);
    EXPECT_FALSE(e2->is_failed_);
    EXPECT_FALSE(e3->is_failed_);
    EXPECT_FALSE(e4->is_failed_);
    EXPECT_FALSE(e5->is_failed_);

    // initiated by StructElement
    cout << "check pi, pj... " << endl;
    EXPECT_EQ(p1, e1->pi_);
    EXPECT_EQ(p2, e1->pj_);
    EXPECT_EQ(p3, e2->pi_);
    EXPECT_EQ(p4, e2->pj_);
    EXPECT_EQ(p5, e3->pi_);
    EXPECT_EQ(p6, e3->pj_);
    EXPECT_EQ(p7, e4->pi_);
    EXPECT_EQ(p8, e4->pj_);
    EXPECT_EQ(p7, e5->pi_);
    EXPECT_EQ(p9, e5->pj_);

    // NOTE(review): e5's fi_/fj_ are not checked here -- consider adding
    // them once confirmed they are also zero-initialised.
    cout << "check fi, fj... " << endl;
    StdArray6d target = {0, 0, 0, 0, 0, 0};
    EXPECT_THAT(target, ContainerEq(e1->fi_));
    EXPECT_THAT(target, ContainerEq(e2->fi_));
    EXPECT_THAT(target, ContainerEq(e3->fi_));
    EXPECT_THAT(target, ContainerEq(e4->fi_));
    EXPECT_THAT(target, ContainerEq(e1->fj_));
    EXPECT_THAT(target, ContainerEq(e2->fj_));
    EXPECT_THAT(target, ContainerEq(e3->fj_));
    EXPECT_THAT(target, ContainerEq(e4->fj_));

    cout << "check length... " << endl;
    EXPECT_DOUBLE_EQ(10.0, e1->length_);
    EXPECT_DOUBLE_EQ(10.0, e2->length_);
    EXPECT_DOUBLE_EQ(10.0, e3->length_);
    EXPECT_DOUBLE_EQ(10.0, e4->length_);
    EXPECT_DOUBLE_EQ(10.0 * sqrt(2), e5->length_);

    // lt_ and la_ are both expected to equal the initial length at
    // construction time.
    cout << "check lt_, la_... " << endl;
    EXPECT_DOUBLE_EQ(10.0, e1->lt_);
    EXPECT_DOUBLE_EQ(10.0, e2->lt_);
    EXPECT_DOUBLE_EQ(10.0, e3->lt_);
    EXPECT_DOUBLE_EQ(10.0, e4->lt_);
    EXPECT_DOUBLE_EQ(10.0 * sqrt(2), e5->lt_);
    EXPECT_DOUBLE_EQ(10.0, e1->la_);
    EXPECT_DOUBLE_EQ(10.0, e2->la_);
    EXPECT_DOUBLE_EQ(10.0, e3->la_);
    EXPECT_DOUBLE_EQ(10.0, e4->la_);
    EXPECT_DOUBLE_EQ(10.0 * sqrt(2), e5->la_);

    // All six components must be checked: fx, sfy, sfz, mx, my, mz
    // (the same component set as emfi_/emfj_ further down).
    cout << "check element_force... " << endl;
    EXPECT_DOUBLE_EQ(0.0, e1->element_force_.fx);
    EXPECT_DOUBLE_EQ(0.0, e1->element_force_.sfy);
    EXPECT_DOUBLE_EQ(0.0, e1->element_force_.sfz);
    EXPECT_DOUBLE_EQ(0.0, e1->element_force_.mx);
    EXPECT_DOUBLE_EQ(0.0, e1->element_force_.my);
    EXPECT_DOUBLE_EQ(0.0, e1->element_force_.mz);

    // initiate by Beam2D
    // The element's cached positions must match what its particles report.
    cout << "check curr_posi, curr_posj, prev_posi, prev_posj... " << endl;
    EXPECT_EQ(p1->current_position(), e1->curr_posi_);
    EXPECT_EQ(p2->current_position(), e1->curr_posj_);
    EXPECT_EQ(p3->current_position(), e2->curr_posi_);
    EXPECT_EQ(p4->current_position(), e2->curr_posj_);
    EXPECT_EQ(p5->current_position(), e3->curr_posi_);
    EXPECT_EQ(p6->current_position(), e3->curr_posj_);
    EXPECT_EQ(p7->current_position(), e4->curr_posi_);
    EXPECT_EQ(p8->current_position(), e4->curr_posj_);
    EXPECT_EQ(p7->current_position(), e5->curr_posi_);
    EXPECT_EQ(p9->current_position(), e5->curr_posj_);
    EXPECT_EQ(p1->previous_position(), e1->prev_posi_);
    EXPECT_EQ(p2->previous_position(), e1->prev_posj_);
    EXPECT_EQ(p3->previous_position(), e2->prev_posi_);
    EXPECT_EQ(p4->previous_position(), e2->prev_posj_);
    EXPECT_EQ(p5->previous_position(), e3->prev_posi_);
    EXPECT_EQ(p6->previous_position(), e3->prev_posj_);
    EXPECT_EQ(p7->previous_position(), e4->prev_posi_);
    EXPECT_EQ(p8->previous_position(), e4->prev_posj_);
    EXPECT_EQ(p7->previous_position(), e5->prev_posi_);
    EXPECT_EQ(p9->previous_position(), e5->prev_posj_);

    // Particles attached to a Beam2D are expected to carry exactly the
    // DOF keys Ux, Uy and Rotz.
    // NOTE(review): only p1..p3 are checked; p4..p9 are not covered.
    cout << "check particle dof..." << endl;
    std::set<std::string> key;
    key.insert("Ux");
    key.insert("Uy");
    key.insert("Rotz");
    EXPECT_THAT(key, ContainerEq(p1->dof_key_));
    EXPECT_THAT(key, ContainerEq(p2->dof_key_));
    EXPECT_THAT(key, ContainerEq(p3->dof_key_));

    // Each element contributes half of its mass to each end particle;
    // p7 belongs to both e4 and e5 and therefore receives two halves.
    cout << "check particle mass..." << endl;
    double m1 = mass1 / 2.0;
    double m2 = mass1 / 2.0;
    double m3 = mass2 / 2.0;
    double m4 = mass2 / 2.0;
    double m5 = mass3 / 2.0;
    double m6 = mass3 / 2.0;
    double m7 = mass4 / 2.0 + mass5 / 2.0;
    double m8 = mass4 / 2.0;
    double m9 = mass5 / 2.0;

    // Expected lumped mass layout: entries 0 and 1 hold the particle mass,
    // entries 2..4 are zero, entry 5 holds mass * Izz / A.
    StdArray6d pm1 = {m1, m1, 0, 0, 0, m1 * e1->section_->Izz_ / e1->section_->A_};
    StdArray6d pm2 = {m2, m2, 0, 0, 0, m2 * e1->section_->Izz_ / e1->section_->A_};
    StdArray6d pm3 = {m3, m3, 0, 0, 0, m3 * e2->section_->Izz_ / e2->section_->A_};
    StdArray6d pm4 = {m4, m4, 0, 0, 0, m4 * e2->section_->Izz_ / e2->section_->A_};
    StdArray6d pm5 = {m5, m5, 0, 0, 0, m5 * e3->section_->Izz_ / e3->section_->A_};
    StdArray6d pm6 = {m6, m6, 0, 0, 0, m6 * e3->section_->Izz_ / e3->section_->A_};
    StdArray6d pm7 = {m7, m7, 0, 0, 0,
        (mass4 * e4->section_->Izz_ / e4->section_->A_ +
         mass5 * e5->section_->Izz_ / e5->section_->A_) / 2.0};
    StdArray6d pm8 = {m8, m8, 0, 0, 0, m8 * e4->section_->Izz_ / e4->section_->A_};
    StdArray6d pm9 = {m9, m9, 0, 0, 0, m9 * e5->section_->Izz_ / e5->section_->A_};

    EXPECT_THAT(pm1, ContainerEq(p1->mass()->lumped_mass));
    EXPECT_THAT(pm2, ContainerEq(p2->mass()->lumped_mass));
    EXPECT_THAT(pm3, ContainerEq(p3->mass()->lumped_mass));
    EXPECT_THAT(pm4, ContainerEq(p4->mass()->lumped_mass));
    EXPECT_THAT(pm5, ContainerEq(p5->mass()->lumped_mass));
    EXPECT_THAT(pm6, ContainerEq(p6->mass()->lumped_mass));
    EXPECT_THAT(pm7, ContainerEq(p7->mass()->lumped_mass));
    EXPECT_THAT(pm8, ContainerEq(p8->mass()->lumped_mass));
    EXPECT_THAT(pm9, ContainerEq(p9->mass()->lumped_mass));

    // Both full mass matrices are expected to be all zeros for every
    // particle.
    Eigen::Matrix3d m, Im;
    m.setZero();
    Im.setZero();
    for (int i = 0; i < 3; i++)
    {
        for (int j = 0; j < 3; j++)
        {
            // translate mass matrix
            EXPECT_DOUBLE_EQ(m(i,j), p1->mass()->translate_mass_matrix(i,j));
            EXPECT_DOUBLE_EQ(m(i,j), p2->mass()->translate_mass_matrix(i,j));
            EXPECT_DOUBLE_EQ(m(i,j), p3->mass()->translate_mass_matrix(i,j));
            EXPECT_DOUBLE_EQ(m(i,j), p4->mass()->translate_mass_matrix(i,j));
            EXPECT_DOUBLE_EQ(m(i,j), p5->mass()->translate_mass_matrix(i,j));
            EXPECT_DOUBLE_EQ(m(i,j), p6->mass()->translate_mass_matrix(i,j));
            EXPECT_DOUBLE_EQ(m(i,j), p7->mass()->translate_mass_matrix(i,j));
            EXPECT_DOUBLE_EQ(m(i,j), p8->mass()->translate_mass_matrix(i,j));
            EXPECT_DOUBLE_EQ(m(i,j), p9->mass()->translate_mass_matrix(i,j));

            // rotation mass matrix
            EXPECT_DOUBLE_EQ(Im(i,j), p1->mass()->rotation_mass_matrix(i,j));
            EXPECT_DOUBLE_EQ(Im(i,j), p2->mass()->rotation_mass_matrix(i,j));
            EXPECT_DOUBLE_EQ(Im(i,j), p3->mass()->rotation_mass_matrix(i,j));
            EXPECT_DOUBLE_EQ(Im(i,j), p4->mass()->rotation_mass_matrix(i,j));
            EXPECT_DOUBLE_EQ(Im(i,j), p5->mass()->rotation_mass_matrix(i,j));
            EXPECT_DOUBLE_EQ(Im(i,j), p6->mass()->rotation_mass_matrix(i,j));
            EXPECT_DOUBLE_EQ(Im(i,j), p7->mass()->rotation_mass_matrix(i,j));
            EXPECT_DOUBLE_EQ(Im(i,j), p8->mass()->rotation_mass_matrix(i,j));
            EXPECT_DOUBLE_EQ(Im(i,j), p9->mass()->rotation_mass_matrix(i,j));
        }
    }

    cout << "check translate_mass_matrix_on... " << endl;
    EXPECT_FALSE(p1->translate_mass_matrix_on_);
    EXPECT_FALSE(p2->translate_mass_matrix_on_);
    EXPECT_FALSE(p3->translate_mass_matrix_on_);
    EXPECT_FALSE(p4->translate_mass_matrix_on_);
    EXPECT_FALSE(p5->translate_mass_matrix_on_);
    EXPECT_FALSE(p6->translate_mass_matrix_on_);
    EXPECT_FALSE(p7->translate_mass_matrix_on_);
    EXPECT_FALSE(p8->translate_mass_matrix_on_);
    EXPECT_FALSE(p9->translate_mass_matrix_on_);

    cout << "check rotation_mass_matrix_on... " << endl;
    EXPECT_FALSE(p1->rotation_mass_matrix_on_);
    EXPECT_FALSE(p2->rotation_mass_matrix_on_);
    EXPECT_FALSE(p3->rotation_mass_matrix_on_);
    EXPECT_FALSE(p4->rotation_mass_matrix_on_);
    EXPECT_FALSE(p5->rotation_mass_matrix_on_);
    EXPECT_FALSE(p6->rotation_mass_matrix_on_);
    EXPECT_FALSE(p7->rotation_mass_matrix_on_);
    EXPECT_FALSE(p8->rotation_mass_matrix_on_);
    EXPECT_FALSE(p9->rotation_mass_matrix_on_);

    // Expected initial stiffness of e1, built from section properties only:
    //   [ A    0      0    ]
    //   [ 0   4*Izz  2*Izz ]
    //   [ 0   2*Izz  4*Izz ]
    cout << "check stiffness..." << endl;
    Eigen::Matrix3d stiff;
    stiff << e1->section_->A_, 0, 0,
             0, 4.0*e1->section_->Izz_, 2.0*e1->section_->Izz_,
             0, 2.0*e1->section_->Izz_, 4.0*e1->section_->Izz_;
    for (int i = 0; i < 3; i++)
    {
        for (int j = 0; j < 3; j++)
        {
            EXPECT_DOUBLE_EQ(stiff(i,j), e1->stiffness_(i,j));
        }
    }

    cout << "check force..." << endl;
    for (int i = 0; i < 3; i++)
    {
        EXPECT_DOUBLE_EQ(0.0, e1->force_(i));
    }

    cout << "check emfi, emfj..." << endl;
    EXPECT_DOUBLE_EQ(0.0, e1->emfi_.fx);
    EXPECT_DOUBLE_EQ(0.0, e1->emfi_.sfy);
    EXPECT_DOUBLE_EQ(0.0, e1->emfi_.sfz);
    EXPECT_DOUBLE_EQ(0.0, e1->emfi_.mx);
    EXPECT_DOUBLE_EQ(0.0, e1->emfi_.my);
    EXPECT_DOUBLE_EQ(0.0, e1->emfi_.mz);
    EXPECT_DOUBLE_EQ(0.0, e1->emfj_.fx);
    EXPECT_DOUBLE_EQ(0.0, e1->emfj_.sfy);
    EXPECT_DOUBLE_EQ(0.0, e1->emfj_.sfz);
    EXPECT_DOUBLE_EQ(0.0, e1->emfj_.mx);
    EXPECT_DOUBLE_EQ(0.0, e1->emfj_.my);
    EXPECT_DOUBLE_EQ(0.0, e1->emfj_.mz);

    cout << "check exa, eya, eza..." << endl;
    EXPECT_DOUBLE_EQ(1.0, e1->exa_(0));
    EXPECT_DOUBLE_EQ(0.0, e1->exa_(1));
    EXPECT_DOUBLE_EQ(0.0, e1->exa_(2));

    for (int i = 0; i < 3; i++)
    {
        EXPECT_DOUBLE_EQ(0.0, e1->eya_(i));
        EXPECT_DOUBLE_EQ(0.0, e1->eza_(i));
    }

    cout << "check elem_rotation_angle..." << endl;
    EXPECT_DOUBLE_EQ(0.0, e1->elem_rotation_angle_);
}


// Placeholder for Beam2D::calcElementForce(); intentionally contains no
// assertions (see the note in the body).
TEST_F(Beam2DTest, CalcElementForce)
{
    // The result is hard to evaluate in isolation, so calcElementForce() is
    // exercised through dedicated Beam2D test cases instead (located in
    // ../beam):
    //   test_cantilever_beam2d.cpp
    //   test_simply_supported_beam2d.cpp
};


// // Tests that the Beam2D::calcCriticalTimeStep() method.
// TEST_F(Beam2DTest, calcCriticalTimeStep)
// {
//     cout << "\n---------- Beam2DTest calcCriticalTimeStep ----------" << endl;
//     double dt = e1->length_ / e1->material_->calcCe();
//     EXPECT_DOUBLE_EQ(dt, e1->calcCriticalTimeStep());
// };



